{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
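    "# Channel Pruning for Image Classification Models: Quick Start\n",
    "\n",
    "This tutorial uses the image classification model MobileNetV1 as an example to show how to get started with [PaddleSlim's convolution channel pruning API]().\n",
    "The example consists of the following steps:\n",
    "\n",
    "1. Import dependencies\n",
    "2. Build the model\n",
    "3. Prune convolution channels\n",
    "4. Train the pruned model\n",
    "\n",
    "The following sections describe each step in turn.\n",
    "\n",
    "## 1. Import dependencies\n",
    "\n",
    "PaddleSlim depends on Paddle 1.7. Make sure Paddle is installed correctly, then import Paddle and PaddleSlim as follows:"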
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "import paddle\n",
    "import paddle.fluid as fluid\n",
    "import paddleslim as slim"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
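    "## 2. Build the model\n",
    "\n",
    "This section builds a classification model for the MNIST dataset. It uses `MobileNetV1` with an input shape of `[1, 28, 28]` and 10 output classes.\n",
    "To keep the example short, `paddleslim.models` provides predefined helpers for building classification models. Run the following code to build the model:"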
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "exe, train_program, val_program, inputs, outputs = slim.models.image_classification(\"MobileNet\", [1, 28, 28], 10, use_gpu=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
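    ">Note: the APIs under `paddleslim.models` are not part of PaddleSlim's regular API. They are predefined helpers that wrap things such as the model definition and Program construction in order to simplify the examples."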
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Prune convolution channels\n",
    "\n",
    "### 3.1 Compute FLOPs before pruning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "FLOPs: 10907072.0\n"
     ]
    }
   ],
   "source": [
    "FLOPs = slim.analysis.flops(train_program)\n",
    "print(\"FLOPs: {}\".format(FLOPs))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
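    "### 3.2 Prune\n",
    "\n",
    "Here we prune the convolution layers whose parameters are named `conv2_1_sep_weights` and `conv2_2_sep_weights`, removing 33% of the channels from each (`ratios=[0.33] * 2`).\n",
    "The code is as follows:"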
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "pruner = slim.prune.Pruner()\n",
    "pruned_program, _, _ = pruner.prune(\n",
    "        train_program,\n",
    "        fluid.global_scope(),\n",
    "        params=[\"conv2_1_sep_weights\", \"conv2_2_sep_weights\"],\n",
    "        ratios=[0.33] * 2,\n",
    "        place=fluid.CPUPlace())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
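    "The call above returns `pruned_program`, in which the definitions of the corresponding convolution parameters are modified, and it also prunes the parameter arrays stored in `fluid.global_scope()`."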
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.3 Compute FLOPs after pruning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "FLOPs = slim.analysis.flops(pruned_program)\n",
    "print(\"FLOPs: {}\".format(FLOPs))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
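    "## 4. Train the pruned model\n",
    "\n",
    "### 4.1 Define the input data\n",
    "\n",
    "To keep this example fast, we use the simple MNIST dataset. The `paddle.dataset.mnist` package in the Paddle framework handles downloading and reading MNIST.\n",
    "The code is as follows:"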
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "import paddle.dataset.mnist as reader\n",
    "train_reader = paddle.batch(\n",
    "        reader.train(), batch_size=128, drop_last=True)\n",
    "train_feeder = fluid.DataFeeder(inputs, fluid.CPUPlace())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.2 Run training\n",
    "The following code trains for one `epoch`, printing `acc1`, `acc5`, and `loss` for each batch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.1484375] [0.4921875] [2.6727316]\n",
      "[0.125] [0.546875] [2.6547904]\n",
      "[0.125] [0.5546875] [2.795205]\n",
      "[0.1171875] [0.578125] [2.8561475]\n",
      "[0.1875] [0.59375] [2.470603]\n",
      "[0.1796875] [0.578125] [2.8031898]\n",
      "[0.1484375] [0.6015625] [2.7530417]\n",
      "[0.1953125] [0.640625] [2.711596]\n",
      "[0.125] [0.59375] [2.8637898]\n",
      "[0.1796875] [0.53125] [2.9473038]\n",
      "[0.25] [0.671875] [2.3943179]\n",
      "[0.25] [0.6953125] [2.632146]\n",
      "[0.2578125] [0.7265625] [2.723265]\n",
      "[0.359375] [0.765625] [2.4263484]\n",
      "[0.3828125] [0.8203125] [2.226284]\n",
      "[0.421875] [0.8203125] [1.8042578]\n",
      "[0.4765625] [0.890625] [1.6841211]\n",
      "[0.53125] [0.8671875] [2.1971617]\n",
      "[0.5546875] [0.8984375] [1.5361531]\n",
      "[0.53125] [0.890625] [1.7211896]\n",
      "[0.5078125] [0.8984375] [1.6586945]\n",
      "[0.53125] [0.9140625] [1.8980236]\n",
      "[0.546875] [0.9453125] [1.5279069]\n",
      "[0.5234375] [0.8828125] [1.7356458]\n",
      "[0.6015625] [0.9765625] [1.0375824]\n",
      "[0.5546875] [0.921875] [1.639497]\n",
      "[0.6015625] [0.9375] [1.5469061]\n",
      "[0.578125] [0.96875] [1.3573356]\n",
      "[0.65625] [0.9453125] [1.3787829]\n",
      "[0.640625] [0.9765625] [0.9946856]\n",
      "[0.65625] [0.96875] [1.1651027]\n",
      "[0.625] [0.984375] [1.0487883]\n",
      "[0.7265625] [0.9609375] [1.2526855]\n",
      "[0.7265625] [0.9765625] [1.2954011]\n",
      "[0.65625] [0.96875] [1.1181556]\n",
      "[0.71875] [0.9765625] [0.97891223]\n",
      "[0.640625] [0.9609375] [1.2135172]\n",
      "[0.7265625] [0.9921875] [0.8950747]\n",
      "[0.7578125] [0.96875] [1.0864108]\n",
      "[0.734375] [0.9921875] [0.8392239]\n",
      "[0.796875] [0.9609375] [0.7012155]\n",
      "[0.7734375] [0.9765625] [0.7409136]\n",
      "[0.8046875] [0.984375] [0.6108341]\n",
      "[0.796875] [0.9765625] [0.63867176]\n",
      "[0.7734375] [0.984375] [0.64099216]\n",
      "[0.7578125] [0.9453125] [0.83827704]\n",
      "[0.8046875] [0.9921875] [0.5311729]\n",
      "[0.8984375] [0.9921875] [0.36445504]\n",
      "[0.859375] [0.9921875] [0.40577835]\n",
      "[0.8125] [0.9765625] [0.64629185]\n",
      "[0.84375] [1.] [0.38400555]\n",
      "[0.890625] [0.9765625] [0.45866236]\n",
      "[0.8828125] [0.9921875] [0.3711415]\n",
      "[0.7578125] [0.9921875] [0.6650479]\n",
      "[0.7578125] [0.984375] [0.9030752]\n",
      "[0.8671875] [0.9921875] [0.3678714]\n",
      "[0.7421875] [0.9765625] [0.7424855]\n",
      "[0.7890625] [1.] [0.6212543]\n",
      "[0.8359375] [1.] [0.58529043]\n",
      "[0.8203125] [0.96875] [0.5860813]\n",
      "[0.8671875] [0.9921875] [0.415236]\n",
      "[0.8125] [1.] [0.60501564]\n",
      "[0.796875] [0.9765625] [0.60677457]\n",
      "[0.8515625] [1.] [0.5338207]\n",
      "[0.8046875] [0.9921875] [0.54180473]\n",
      "[0.875] [0.9921875] [0.7293667]\n",
      "[0.84375] [0.9765625] [0.5581689]\n",
      "[0.8359375] [1.] [0.50712734]\n",
      "[0.8671875] [0.9921875] [0.55217856]\n",
      "[0.765625] [0.96875] [0.8076792]\n",
      "[0.953125] [1.] [0.17031987]\n",
      "[0.890625] [0.9921875] [0.42383268]\n",
      "[0.828125] [0.9765625] [0.49300486]\n",
      "[0.8671875] [0.96875] [0.57985115]\n",
      "[0.8515625] [1.] [0.4901033]\n",
      "[0.921875] [1.] [0.34583277]\n",
      "[0.8984375] [0.984375] [0.41139168]\n",
      "[0.9296875] [1.] [0.20420414]\n",
      "[0.921875] [0.984375] [0.24322833]\n",
      "[0.921875] [0.9921875] [0.30570173]\n",
      "[0.875] [0.9921875] [0.3866225]\n",
      "[0.9140625] [0.9921875] [0.20813875]\n",
      "[0.9140625] [1.] [0.17933217]\n",
      "[0.8984375] [0.9921875] [0.32508463]\n",
      "[0.9375] [1.] [0.24799153]\n",
      "[0.9140625] [1.] [0.26146784]\n",
      "[0.90625] [1.] [0.24672262]\n",
      "[0.8828125] [1.] [0.34094217]\n",
      "[0.90625] [1.] [0.2964819]\n",
      "[0.9296875] [1.] [0.18237087]\n",
      "[0.84375] [1.] [0.7182543]\n",
      "[0.8671875] [0.984375] [0.508474]\n",
      "[0.8828125] [0.9921875] [0.367172]\n",
      "[0.9453125] [1.] [0.2366665]\n",
      "[0.9375] [1.] [0.12494276]\n",
      "[0.8984375] [1.] [0.3395289]\n",
      "[0.890625] [0.984375] [0.30877113]\n",
      "[0.90625] [1.] [0.29763448]\n",
      "[0.8828125] [0.984375] [0.4845504]\n",
      "[0.8515625] [1.] [0.45548072]\n",
      "[0.8828125] [1.] [0.33331633]\n",
      "[0.90625] [1.] [0.4024018]\n",
      "[0.890625] [0.984375] [0.73405886]\n",
      "[0.9609375] [0.9921875] [0.15409982]\n",
      "[0.9140625] [0.984375] [0.37103674]\n",
      "[0.953125] [1.] [0.17628372]\n",
      "[0.890625] [1.] [0.36522508]\n",
      "[0.8828125] [1.] [0.407708]\n",
      "[0.9375] [0.984375] [0.25090045]\n",
      "[0.890625] [0.984375] [0.35742313]\n",
      "[0.921875] [0.9921875] [0.2751101]\n",
      "[0.890625] [0.984375] [0.43053097]\n",
      "[0.875] [0.9921875] [0.34412643]\n",
      "[0.90625] [1.] [0.35595697]\n"
     ]
    }
   ],
   "source": [
    "for data in train_reader():\n",
    "    acc1, acc5, loss = exe.run(pruned_program, feed=train_feeder.feed(data), fetch_list=outputs)\n",
    "    print(acc1, acc5, loss)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
